{
 "cells": [
  {
   "cell_type": "raw",
   "id": "529aeba9",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_label: Fireworks\n",
    "---"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "642fd21c-600a-47a1-be96-6e1438b421a9",
   "metadata": {},
   "source": [
    "# ChatFireworks\n",
    "\n",
    ">[Fireworks](https://app.fireworks.ai/) accelerates product development on generative AI by creating an innovative AI experiment and production platform. \n",
    "\n",
    "This notebook shows how to use LangChain to interact with Fireworks chat models through the `ChatFireworks` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d00d850917865298",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models.fireworks import ChatFireworks\n",
    "from langchain.schema import HumanMessage, SystemMessage"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f28ebf8b-f14f-46c7-9962-8b8dc42e31be",
   "metadata": {},
   "source": [
    "# Setup\n",
    "\n",
    "1. Make sure the `fireworks-ai` package is installed in your environment.\n",
    "2. Sign in to [Fireworks AI](http://fireworks.ai) to get an API key for accessing our models, and make sure it is set as the `FIREWORKS_API_KEY` environment variable.\n",
    "3. Set up your model using a model ID. If no model is specified, the default model is `accounts/fireworks/models/llama-v2-7b-chat`. See the full, most up-to-date model list on [app.fireworks.ai](https://app.fireworks.ai)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d096fb14-8acc-4047-9cd0-c842430c3a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "if \"FIREWORKS_API_KEY\" not in os.environ:\n",
    "    os.environ[\"FIREWORKS_API_KEY\"] = getpass.getpass(\"Fireworks API Key:\")\n",
    "\n",
    "# Initialize a Fireworks chat model\n",
    "chat = ChatFireworks(model=\"accounts/fireworks/models/llama-v2-13b-chat\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8f13144-37cf-47a5-b5a0-e3cdf76d9a72",
   "metadata": {},
   "source": [
    "# Calling the Model Directly\n",
    "\n",
    "You can call the model directly with a list of messages, for example a system message followed by a human message."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "72340871-ae2f-415f-b399-0777d32dc379",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Hello! My name is LLaMA, I'm a large language model trained by a team of researcher at Meta AI. My primary function is to assist and converse with users like you, answering questions and engaging in discussion to the best of my ability. I'm here to help and provide information on a wide range of topics, so feel free to ask me anything!\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Call the ChatFireworks model with a list of messages\n",
    "system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
    "human_message = HumanMessage(content=\"Who are you?\")\n",
    "\n",
    "chat.invoke([system_message, human_message])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "68c6b1fa-2ff7-4a63-8d88-3cec302180b8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Oh hello there! *giggle* It's such a beautiful day today, isn\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Setting additional parameters: temperature, max_tokens, top_p\n",
    "chat = ChatFireworks(\n",
    "    model=\"accounts/fireworks/models/llama-v2-13b-chat\",\n",
    "    model_kwargs={\"temperature\": 1, \"max_tokens\": 20, \"top_p\": 1},\n",
    ")\n",
    "system_message = SystemMessage(content=\"You are to chat with the user.\")\n",
    "human_message = HumanMessage(content=\"How's the weather today?\")\n",
    "chat.invoke([system_message, human_message])"
   ]
  },
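  {
   "cell_type": "markdown",
   "id": "9e3b1c47",
   "metadata": {},
   "source": [
    "The `model_kwargs` above control decoding: `temperature` rescales the model's token probabilities, `top_p` restricts sampling to the smallest set of most likely tokens whose cumulative probability reaches `top_p`, and `max_tokens` caps the length of the reply, which is most likely why the answer above stops mid-word. How Fireworks implements sampling on its servers is not covered here. The next cell is only a generic, standard-library sketch of temperature scaling plus nucleus (top-p) filtering over a toy vocabulary; `sample_next_token`, `generate`, and `toy_logits` are hypothetical names, not part of LangChain or the Fireworks API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a7d2f90",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def sample_next_token(logits, temperature=1.0, top_p=1.0, rng=random):\n",
    "    \"\"\"Hypothetical toy sampler: temperature scaling, then nucleus (top-p) filtering.\"\"\"\n",
    "    # In this sketch, temperature 0 means greedy decoding (always the most likely token).\n",
    "    if temperature == 0:\n",
    "        return max(logits, key=logits.get)\n",
    "    # Softmax over temperature-scaled logits; subtracting the max avoids overflow.\n",
    "    scaled = {tok: val / temperature for tok, val in logits.items()}\n",
    "    peak = max(scaled.values())\n",
    "    exps = {tok: math.exp(val - peak) for tok, val in scaled.items()}\n",
    "    total = sum(exps.values())\n",
    "    ranked = sorted(((e / total, tok) for tok, e in exps.items()), reverse=True)\n",
    "    # Keep the smallest prefix of likely tokens whose cumulative probability reaches top_p.\n",
    "    nucleus, cumulative = [], 0.0\n",
    "    for prob, tok in ranked:\n",
    "        nucleus.append((prob, tok))\n",
    "        cumulative += prob\n",
    "        if cumulative >= top_p:\n",
    "            break\n",
    "    # random.choices treats weights as relative, so no explicit renormalisation is needed.\n",
    "    tokens = [tok for _, tok in nucleus]\n",
    "    weights = [prob for prob, _ in nucleus]\n",
    "    return rng.choices(tokens, weights=weights, k=1)[0]\n",
    "\n",
    "\n",
    "def generate(logits_per_step, max_tokens, **kwargs):\n",
    "    \"\"\"Sample one token per step, stopping after max_tokens tokens.\"\"\"\n",
    "    return [sample_next_token(logits, **kwargs) for logits in logits_per_step[:max_tokens]]\n",
    "\n",
    "\n",
    "toy_logits = {\"sunny\": 2.0, \"cloudy\": 1.0, \"rainy\": 0.5}\n",
    "print(sample_next_token(toy_logits, temperature=0))  # greedy: always \"sunny\"\n",
    "# With top_p=0.8, \"rainy\" falls outside the nucleus in this toy example and is never sampled.\n",
    "print(generate([toy_logits] * 100, max_tokens=5, temperature=1.0, top_p=0.8, rng=random.Random(0)))"
   ]
  },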
  {
   "cell_type": "markdown",
   "id": "d93aa186-39cf-4e1a-aa32-01ed31d43bc8",
   "metadata": {},
   "source": [
    "# Simple Chat Chain"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28763fbc",
   "metadata": {},
   "source": [
    "You can use Fireworks chat models in a chain, together with a system prompt and conversation memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cbe29efc-37c3-4c83-8b84-b8bba1a1e589",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatFireworks\n",
    "from langchain.memory import ConversationBufferMemory\n",
    "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "\n",
    "llm = ChatFireworks(\n",
    "    model=\"accounts/fireworks/models/llama-v2-13b-chat\",\n",
    "    model_kwargs={\"temperature\": 0, \"max_tokens\": 64, \"top_p\": 1.0},\n",
    ")\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", \"You are a helpful chatbot that speaks like a pirate.\"),\n",
    "        MessagesPlaceholder(variable_name=\"history\"),\n",
    "        (\"human\", \"{input}\"),\n",
    "    ]\n",
    ")"
   ]
  },
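  {
   "cell_type": "markdown",
   "id": "c15e8b62",
   "metadata": {},
   "source": [
    "`MessagesPlaceholder(variable_name=\"history\")` splices whatever list of messages is passed as `history` between the system message and the new human message, so each request carries the system prompt, the stored conversation, and the latest input. The actual messages can be inspected with `prompt.format_messages(history=[], input=\"hi\")`. As a dependency-free illustration of that ordering only, the next cell uses plain `(role, content)` tuples; `build_messages` is a hypothetical helper, not LangChain's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f0a3d18",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_messages(system, history, user_input):\n",
    "    \"\"\"Hypothetical helper: system message, then stored history, then the new human turn.\"\"\"\n",
    "    return [(\"system\", system), *history, (\"human\", user_input)]\n",
    "\n",
    "\n",
    "toy_history = [(\"human\", \"hi im bob\"), (\"ai\", \"Ahoy there, me hearty!\")]\n",
    "toy_messages = build_messages(\n",
    "    \"You are a helpful chatbot that speaks like a pirate.\", toy_history, \"whats my name\"\n",
    ")\n",
    "for role, content in toy_messages:\n",
    "    print(f\"{role}: {content}\")"
   ]
  },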
  {
   "cell_type": "markdown",
   "id": "02991e05-a38e-47d4-9ab3-7e630a8ead55",
   "metadata": {},
   "source": [
    "Initially, the chat memory is empty."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e2fd186f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'history': []}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "memory = ConversationBufferMemory(return_messages=True)\n",
    "memory.load_memory_variables({})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bee461da",
   "metadata": {},
   "source": [
    "Create a simple chain that loads the conversation history from memory and passes it to the prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "86972e54",
   "metadata": {},
   "outputs": [],
   "source": [
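    "from langchain_core.runnables import RunnableLambda\n",
    "\n",
    "# Wrap the bound method in RunnableLambda so it can be piped; a plain method cannot use `|`.\n",
    "chain = (\n",
    "    RunnablePassthrough.assign(\n",
    "        history=RunnableLambda(memory.load_memory_variables) | (lambda x: x[\"history\"])\n",
    "    )\n",
    "    | prompt\n",
    "    | llm.bind(stop=[\"\\n\\n\"])\n",
    ")"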
   ]
  },
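  {
   "cell_type": "markdown",
   "id": "e62c9a05",
   "metadata": {},
   "source": [
    "Two details of this chain are worth noting. `RunnablePassthrough.assign` reads the memory on every `invoke`, so the prompt always sees the history saved so far; in this notebook, each turn is saved with an explicit `memory.save_context` call. The `stop` sequence bound to the model ends each reply at the first blank line. Stopping is normally applied by the model provider during generation; the next cell only sketches the effect on a finished string, assuming the stop sequence itself is not included in the output. `truncate_at_stop` is a hypothetical helper, not a LangChain function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b38f4e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncate_at_stop(text, stop):\n",
    "    \"\"\"Hypothetical helper: cut text at the earliest occurrence of any stop sequence.\"\"\"\n",
    "    cut = len(text)\n",
    "    for seq in stop:\n",
    "        idx = text.find(seq)\n",
    "        # Ignore empty stop strings, which would otherwise match at position 0.\n",
    "        if seq and idx != -1:\n",
    "            cut = min(cut, idx)\n",
    "    return text[:cut]\n",
    "\n",
    "\n",
    "toy_reply = \"Ahoy there, me hearty!\\n\\nHuman: whats my name\"\n",
    "print(repr(truncate_at_stop(toy_reply, [\"\\n\\n\"])))  # 'Ahoy there, me hearty!'"
   ]
  },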
  {
   "cell_type": "markdown",
   "id": "f48cb142",
   "metadata": {},
   "source": [
    "Run the chain with a simple greeting. The reply should follow the pirate persona set in the system message."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "db3ad5b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Ahoy there, me hearty! Yer a fine lookin' swashbuckler, I can see that! *adjusts eye patch* What be bringin' ye to these waters? Are ye here to plunder some booty or just to enjoy the sea breeze?\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = {\"input\": \"hi im bob\"}\n",
    "response = chain.invoke(inputs)\n",
    "response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "338f4bae",
   "metadata": {},
   "source": [
    "Save this exchange to memory, then read it back to inspect its contents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "257eec01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'history': [HumanMessage(content='hi im bob', additional_kwargs={}, example=False),\n",
       "  AIMessage(content=\"Ahoy there, me hearty! Yer a fine lookin' swashbuckler, I can see that! *adjusts eye patch* What be bringin' ye to these waters? Are ye here to plunder some booty or just to enjoy the sea breeze?\", additional_kwargs={}, example=False)]}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "memory.save_context(inputs, {\"output\": response.content})\n",
    "memory.load_memory_variables({})"
   ]
  },
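  {
   "cell_type": "markdown",
   "id": "0d9a6c2e",
   "metadata": {},
   "source": [
    "`ConversationBufferMemory` keeps every saved turn verbatim and returns all of them on each load, so the history, and with it the prompt, grows with every exchange. To make the two calls used here concrete, the next cell is a hypothetical, dependency-free stand-in that mimics only `save_context` and `load_memory_variables`; it is not LangChain's class, and it stores `(role, content)` tuples instead of message objects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b2e7f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToyBufferMemory:\n",
    "    \"\"\"Hypothetical stand-in for a buffer memory that stores every turn verbatim.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.turns = []\n",
    "\n",
    "    def save_context(self, inputs, outputs):\n",
    "        # Record the human input and the model reply as one exchange.\n",
    "        self.turns.append((\"human\", inputs[\"input\"]))\n",
    "        self.turns.append((\"ai\", outputs[\"output\"]))\n",
    "\n",
    "    def load_memory_variables(self, inputs):\n",
    "        # The inputs are ignored here; return a copy so callers cannot mutate the buffer.\n",
    "        return {\"history\": list(self.turns)}\n",
    "\n",
    "\n",
    "toy_memory = ToyBufferMemory()\n",
    "print(toy_memory.load_memory_variables({}))  # {'history': []}\n",
    "toy_memory.save_context({\"input\": \"hi im bob\"}, {\"output\": \"Ahoy there, me hearty!\"})\n",
    "print(toy_memory.load_memory_variables({}))"
   ]
  },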
  {
   "cell_type": "markdown",
   "id": "08441347",
   "metadata": {},
   "source": [
    "Now ask another question that requires the model to use the memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7f5f2820",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Arrrr, ye be askin' about yer name, eh? Well, me matey, I be knowin' ye as Bob, the scurvy dog! *winks* But if ye want me to call ye somethin' else, just let me know, and I\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = {\"input\": \"whats my name\"}\n",
    "chain.invoke(inputs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
